import numpy as np
import scipy
import sklearn
import matplotlib.pyplot as plt
import os

class GenerateData:
    """
    Simulate covariates X, binary actions A and outcomes Y.

    Covariates are drawn uniformly on [-1, 1], actions are Bernoulli draws
    whose probabilities come from ``collecting_rule(X)``, and outcomes are
    drawn around ``eyx(X, A)`` according to ``Ydist``.
    """

    # Values of ``Ydist`` that GenerateData() knows how to handle.
    _SUPPORTED_YDIST = (None, 'normal', 'logit')

    def __init__(self, eyx, collecting_rule, Ydist=None, sigma=1, k=2, N=100, seed=None):
        """
        Initiate with parameters
        :param eyx: data generating p(y=1|x,a); called as ``eyx(X, A)`` and
            its result is ``.squeeze()``-d before use
        :param collecting_rule: a function that determines the probability of action given covariates;
            called as ``collecting_rule(X)``
        :param Ydist: outcome distribution: ``'normal'`` (mean ``eyx(X, A)``,
            scale ``sigma``), ``'logit'`` (Bernoulli with success probability
            ``eyx(X, A)``), or ``None`` to skip drawing outcomes
        :param sigma: scale passed to the normal draw when ``Ydist == 'normal'``
        :param k: number of covariates
        :param N: number of observations
        :param seed: value passed to ``np.random.seed``; note that this
            reseeds NumPy's *global* random state
        """
        np.random.seed(seed)
        self.eyx = eyx
        self.sigma = sigma
        self.Ydist = Ydist
        self.collecting_rule = collecting_rule
        self.N = N
        self.k = k
        # Populated by GenerateData(); None until then.
        self.X = None
        self.A = None
        self.Ey = None
        self.Y = None

    def GenerateData(self):
        """
        Draw a fresh sample and store it on the instance.

        :return: [X, A, Ey, Y]; ``Y`` is ``None`` when ``Ydist`` is ``None``
        :raises ValueError: if ``Ydist`` is not ``None``, ``'normal'`` or ``'logit'``
        """
        # Fail loudly on an unrecognised distribution name (e.g. a typo)
        # instead of silently returning Y=None.
        if self.Ydist not in self._SUPPORTED_YDIST:
            raise ValueError(
                "Unsupported Ydist {!r}; expected None, 'normal' or 'logit'".format(self.Ydist)
            )

        self.X = np.random.uniform(low=-1, high=1, size=(self.N, self.k))
        self.A = np.random.binomial(1, self.collecting_rule(self.X))
        self.Ey = self.eyx(self.X, self.A).squeeze()

        if self.Ydist == 'normal':
            self.Y = np.random.normal(loc=self.Ey, scale=self.sigma)
        elif self.Ydist == 'logit':
            self.Y = np.random.binomial(n=1, p=self.Ey, size=self.N)
        else:
            # Ydist is None: no outcome requested; also clears any stale Y
            # left over from an earlier call with a different Ydist.
            self.Y = None
        return [self.X, self.A, self.Ey, self.Y]

    def plot_Ey(self):
        """Show a histogram of Ey; does nothing if data has not been generated yet."""
        if self.Ey is not None:
            plt.hist(self.Ey)
            plt.show()